{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## We've already seen how to implement a linear regression where we used a single variable to predict the value of another related variable. In the case where we want to predict the value of a variable using more than one variable as input then we need to use matrices.\n", "\n", "In this notebook we'll implement a multivariate linear regression. Here we'll only cover continuous covariate variables but the method works identically if we used categorical covariates - it just requires us to do some extra processing before fitting the model!" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate data for multivariate regression" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 6.00804119 7.32002909 -9.62139603 4.97830887 8.92353483]\n" ] } ], "source": [ "n = 1000 #Number of observations in the training set\n", "p = 5 #Number of parameters, including intercept\n", "\n", "#Assign True parameters to be estimated\n", "beta = np.random.uniform(-10, 10, p) #Randomly initialise true parameters\n", "print(beta)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "X = np.random.uniform(0,10,(n,(p-1))) \n", "X0 = np.array([1]*n).reshape((n,1)) #Columns for intercept\n", "\n", "X = np.concatenate([X0,X], axis = 1) #Join intercept to other variables to form feature matrix\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "Y = np.matmul(X,beta) + np.random.normal(0,10,n) #Linear combination of the features plus a normal error term" ] }, { "cell_type": "code", "execution_count": 5, "metadata": 
{}, "outputs": [], "source": [ "#Concatenate to create dataframe\n", "\n", "dataFeatures = pd.DataFrame(X)\n", "dataFeatures.columns = [f'X{i}' for i in range(p)]\n", "\n", "dataTarget = pd.DataFrame(Y)\n", "dataTarget.columns = ['Y']\n", "\n", "data = pd.concat([dataFeatures, dataTarget], axis = 1)\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
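<div>\n", "<table>\n", "  <thead>\n", "    <tr>\n", "      <th></th>\n", "      <th>X0</th>\n", "      <th>X1</th>\n", "      <th>X2</th>\n", "      <th>X3</th>\n", "      <th>X4</th>\n", "      <th>Y</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>1.0</td>\n", "      <td>2.155360</td>\n", "      <td>9.921219</td>\n", "      <td>0.044325</td>\n", "      <td>0.441540</td>\n", "      <td>-75.531362</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>1.0</td>\n", "      <td>8.666179</td>\n", "      <td>7.390977</td>\n", "      <td>2.859317</td>\n", "      <td>0.218063</td>\n", "      <td>29.703580</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>1.0</td>\n", "      <td>3.618614</td>\n", "      <td>6.246515</td>\n", "      <td>0.075729</td>\n", "      <td>8.564948</td>\n", "      <td>69.347375</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>1.0</td>\n", "      <td>0.425044</td>\n", "      <td>5.063864</td>\n", "      <td>0.974687</td>\n", "      <td>2.666221</td>\n", "      <td>-2.618186</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>1.0</td>\n", "      <td>4.601384</td>\n", "      <td>2.430928</td>\n", "      <td>1.119196</td>\n", "      <td>0.165348</td>\n", "      <td>42.026702</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>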
" ], "text/plain": [ " X0 X1 X2 X3 X4 Y\n", "0 1.0 2.155360 9.921219 0.044325 0.441540 -75.531362\n", "1 1.0 8.666179 7.390977 2.859317 0.218063 29.703580\n", "2 1.0 3.618614 6.246515 0.075729 8.564948 69.347375\n", "3 1.0 0.425044 5.063864 0.974687 2.666221 -2.618186\n", "4 1.0 4.601384 2.430928 1.119196 0.165348 42.026702" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# The Algebra" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To fit a linear regression for a set of features $X$ and a set of targets $Y$, we compute the model parameters as:\n", "\n", "$$\\hat \\beta = (X^TX)^{-1}X^Ty$$\n", "\n", "$\\hat \\beta$ is a $p \\times 1$ vector where each element of the vector corresponds to the estimate of the true parameter which generated the data\n", "\n", "\n", "This estimator is derived using the same ideas as for the single variable case but we have to work with matrices rather than vectors - See [this link](http://home.iitk.ac.in/~shalab/regression/Chapter3-Regression-MultipleLinearRegressionModel.pdf) for a detailed derivation. 
" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class LinearRegressionMultivariate:\n", " \n", " def __init__(self, data, target, features, trainTestRatio = 0.9):\n", " #data - a pandas dataset \n", " #target - the name of the pandas column which contains the true labels\n", " #features - A list containing the names of the columns which we will use to do the regression\n", " #trainTestRatio - the proportion of the entire dataset which we'll use for training\n", " # - the rest will be used for testing\n", " \n", " self.target = target\n", " self.features = features \n", " \n", " #Split up data into a training and testing set\n", " self.train, self.test = train_test_split(data, test_size=1-trainTestRatio)\n", " \n", " \n", " \n", " def fitLR(self):\n", " #Fit a linear regression to the training data\n", " #Useful functions: np.matmul multiplies two matrices together, \n", " #np.transpose returns the transposition of a matrix\n", " #np.linalg.inv returns the inverse of a square matrix\n", " \n", " \n", " \n", " \n", " #Rename train and test data to make the calculation less unpleasant to look at\n", " #Change the data type from pandas dataframe to numpy array\n", " X = np.array(self.train[self.features])\n", " y = np.array(self.train[self.target])\n", " \n", " \n", " #self.betaHat should contain the estimates for the parameters\n", " #Simply a case of implementing the equation above - make sure the matrix dimensions for each term matches up!\n", " self.betaHat = #...\n", " \n", " return 0 #We've saved the parameter values as part of the class now\n", " \n", " def predict(self,x):\n", " #Given a vector (or matrix) of new observations x, predict the corresponding target values\n", " \n", " #This can be done by multiplying x by betaHat - make sure the predictions match up!\n", " \n", " pass\n", " " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "myModel = 
LinearRegressionMultivariate(data, 'Y', [f'X{i}' for i in range(p)])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "myModel.fitLR()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Print the model estimates - there should be the right number (p) of them!" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 6.07510796 7.41825296 -9.70828959 4.82094549 9.00624211]\n", "(5,)\n" ] } ], "source": [ "print(myModel.betaHat)\n", "print(myModel.betaHat.shape) #==p" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict values for the test set" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "testPred = myModel.predict(np.array(myModel.test[myModel.features]))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX8AAAD8CAYAAACfF6SlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi40LCBodHRwOi8vbWF0cGxvdGxpYi5vcmcv7US4rQAAIABJREFUeJzt3Xl8U1X6x/HPaRugbC3IXkTQQRREQfnhgrsobiiiuCszU0AUFBRRmHFXtMq4joBsio4I6oC1LAoKOCgqUGTfFEGFglCEAkKXtDm/P5pC2iZtuqRJk+/79eLV9vbe5DSv8Nyb5z7nOcZai4iIRJaoYA9ARESqnoK/iEgEUvAXEYlACv4iIhFIwV9EJAIp+IuIRKBSg78x5nhjzCJjzEZjzHpjzBD39obGmC+MMT+5vzbwOGakMWaLMWazMaZHIP8AEREpO1Nanb8xpjnQ3Fr7gzGmHrAC6AX8FdhnrU0yxowAGlhrHzXGtAemAV2BFsCXwMnW2rwA/h0iIlIGpV75W2t3WWt/cH9/CNgIJADXA++6d3uX/BMC7u3TrbXZ1tptwBbyTwQiIhIiYsqyszGmNdAZWAo0tdbugvwThDGmiXu3BOB7j8N2uLf51KhRI9u6deuyDEVEJOKtWLFir7W2cXmO9Tv4G2PqAjOAodbag8YYn7t62VYst2SMGQAMAGjVqhWpqan+DkVERABjzK/lPdavah9jjIP8wD/VWjvTvXm3+35AwX2BPe7tO4DjPQ5vCews+pjW2gnW2i7W2i6NG5frxCUiIuXkT7WPASYDG621r3j8KgXo6/6+L/Cpx/ZbjTE1jTFtgLbAssobsoiIVJQ/aZ9uwF3AWmPMKve2fwBJwEfGmETgN6APgLV2vTHmI2ADkAsMUqWPiEhoKTX4W2u/wXseH+AyH8eMAkZVYFwiIhJAmuErIhKBFPxFRCKQgr+ISARS8BcRqWJZuVk8tvAxfvzjx6CNQcFfRKQKLf51MWe8dQajvh7FrM2zgjYOBX8RkSpwMPsg986+l4umXEROXg7z75zPsPOGBW08ZertIyIiZTf7x9ncO+de0g6m8eA5D/LsJc/yxfoMuiUtZGdGJi3iYxneox29OpfYBq1SKfiLiATInsN7GPL5EKavm06Hxh34b+J/Obvl2SSvTGPkzLVkOvPnv6ZlZDJy5lqAKjsBKO0jIlLJrLX8Z/V/OHXMqczYMIOnL36aH+75gbNbng3A6Hmbjwb+ApnOPEbP21xlY9SVv4hIJfo141fumX0P836exzktz2FSz0l0aNKh0D47MzK9HutreyAo+IuIVII8Vx5jl49l5IKRALx+5esM+r9BREdFF9u3RXwsaV4CfYv42ICPs4DSPiIiFbQhfQMXvHMBD3z+AOe3Op91963jgbMf8Br4AYb3aEeso/DvYh3RDO/RriqGC+jKX0Sk3HLyckj6JolRX4+ibo26vNfrPe48/U5KWOwKOHZTd/S8zar2ERGpTpalLSMxJZF1e9ZxS4dbeOOqN2hSp0npB7r16pxQpcG+KAV/EZEyOJxzmMcXPc7rS1+ned3mpNyaQs92PYM9rDJT8BcR8dOXW79kwKwBbMvYxsCzBpLUPYm4WnHBHla5KPiLiJRif+Z+hs0fxjur3qFtw7Z81fcrLmp9UbCHVSEK/iIiPlhrmbFxBoPnDmbvkb2M6DaCJy56glhH1ZVkBoqCv4iIFzsP7WTQ3EEkb0qmc7POfHbHZ3Ru3jnYw6o0Cv4iIh6stUxeOZmH5z9Mdl42SZclMey8YcREhVe4DK+/RkSkArbs28KAWQNY9MsiLjrhIib2nEjb49oGe1gBoeAvIhEv15XLq9+9yhNfPUGN6BqMv3Y8/c7sR5QJ3yYICv4iEtFW/76axJREVuxawXXtrmPs1WNJqB+8yVdVRcFfRCJSVm4Wz/7vWV769iUaxjbkw5s+pE/7PqW2ZggXCv4iEnG+/vVr+s/qz+Y
/NtP3jL68fMXLHFf7uGAPq0op+ItIxDiYfZCRX45kbOpYTog7gc/v+Jwef+kR7GEFhYK/iESEOT/OYeCcgaQdTGPI2UN47tLnqFujbrCHFTQK/iIS1tIPpzPk8yFMWzeN9o3b83Hix5zT8pxgDyvoFPxFJCxZa5m6dipDPx/KweyDPHnRk4w8fyQ1Y2oGe2ghQcFfRMLObwd+Y+DsgXy25TPOTjibyddNLraObqRT8BeRsOGyrqPr6Lqsi9d6vMbgroN9LqcYyRT8RSQsbEzfSL9Z/fh2+7dccdIVjL92PK3jWwd7WCFLwV9EqrWcvBxeWvISzy5+ljqOOky5fgp3n3F3xEzWKi8FfxGptpanLScxJZG1e9Zyc4ebeePKN2hat2mwh1UtKPiLSLVzOOcwTyx6gteWvkazus349NZPua7ddSSvTGP0vIXszMikRXwsw3u0C+oi6aGs1JZ1xpi3jTF7jDHrPLY9ZYxJM8ascv+72uN3I40xW4wxm40xkTl1TkQCZsHWBXQc15FXvn+Ffp37seG+DUcD/8iZa0nLyMQCaRmZjJy5luSVacEeckjyp1/pFOBKL9tftdZ2cv+bC2CMaQ/cCnRwHzPWGKPb7CJSYfsz95P4aSLd/9Od6Khovur7FeN7jj+6gProeZvJdOYVOibTmcfoeZuDMdyQV2rax1q72BjT2s/Hux6Ybq3NBrYZY7YAXYHvyj1CEQkp+amVzWVKrZTnGE8zNsxg8GeDST+czqPdHuXJi54sto7uzoxMr8f62h7pKrJSwWBjzBp3WqiBe1sCsN1jnx3ubcUYYwYYY1KNManp6ekVGIaIVJXklWkM/+/qQqmV4f9dXWJqpSLpmF2HdtH7w97c9PFNNKvbjGX9l5HUPcnrAuot4r0vqm6BbkkLlf4porzBfxxwEtAJ2AW87N7urbbKensAa+0Ea20Xa22Xxo0bl3MYIlKVnp61Hmde4f/SzjzL07PW+zymPOkYay2Tf5jMqWNOZe5Pc3nhshdY1m8ZZzY/0+cxw3u0I9bhPcus/H9x5ar2sdbuLvjeGDMRmO3+cQdwvMeuLYGd5R6diISU/UecZdoOZU/H/LzvZwbMHsDCbQu58IQLmdhzIicfd3KpYytII42et5k0L49dcMJR9U++cl35G2Oae/x4A1BQCZQC3GqMqWmMaQO0BZZVbIgiUp35SscU3Z7ryuVf3/6LjuM6sjxtOW9d8xaL+i7yK/AX6NU5gSUjLvWaggDl/z2VeuVvjJkGXAw0MsbsAJ4ELjbGdCI/pfMLcA+AtXa9MeYjYAOQCwyy1uZ5e1wRqX7iYx1kZBa/yo+Pdfi8qTu8RztGzlxbKPUT64hmeI92R39es3sNiSmJpO5MpefJPRl7zVha1m9Z7nG2iI/1evXv60QUifyp9rnNy+bJJew/ChhVkUGJSGh66roODP94NU7Xsby/I8pw7RnNCwX4ghw7FE7HFD0xZOVm8dzi53hxyYs0qNWA6TdO5+YON1e4NYM/J5xIpxm+IuI3X4G8pJu6vTonHP3n6ZvfvqH/rP5s2ruJu8+4m1eueKXS1tEt6YQj+RT8RaRMvAXyBz9c5XVfbzl2z3V0W8W14rM7PuPKv3ibR1r545RjFPxFpML8zbHP/WkuA2cPZMfBHTzQ9QFGXTYqotfRDaaKTPISEQG819h75tjTD6dzx8w7uOaDa6hXsx5L/r6E1696XYE/iHTlLyIV5ivHfn2nFkxdM5Wh84ZyIOuA1tENIQr+IlIpiubYtx/YzrXTrmXuT3PpmtCVyddN5rQmpwVxhOJJwV9EKpXLuhi3fBwjFozAZV282uNV7u96v9bRDTEK/iJSaTbt3US/lH4s2b6E7id2Z8K1E2jToE2whyVeKPiLSIU585y8tOQlnln8DHUcdXjn+nfoe0ZfPl21kzvHa2WtUKTgLyIVkrozlcSURNbsXkOf9n1446o3aFa32dFWziXN+pXgUamniJTLEecRHp7/MGd
POpu9R/aSfEsyH/X5iGZ1mwFaWSvU6cpfRMps4baF9J/Vn637t9L/zP68dPlLxNeKL7SPVtYKbQr+IhGmIksqZmRl8PD8h5m8cjJ/afgXFt69kEvaXOJ1X3XWDG1K+4hEkIosqThz40xOHXMqU1ZN4ZHzHmHNwDU+Az+UPutXgktX/iIRpLTum978/ufvDJ47mBkbZ3BG0zOYfdtszmpxVqnPpc6aoU3BXySClCUPb63lnVXvMGz+MDKdmTx/6fM8fN7DOKIdfj+fOmuGLgV/kQjibx5+6/6tDJg1gAXbFnBBqwuY2HMi7RopXRNOlPMXiSCl5eHzXHn87b9P0vaN9izc+i0nxgxl6BlTFfjDkK78RSJISXn4tbvX0nv6XWzJWE1s3v/R0HkfeVmN+ecn64kyUUrfhBljrS19rwDr0qWLTU1NDfYwRCKGZ7lns7hojm/9OZ9sGQOuusTn9Kd23oUYjq2jmxAfy5IRlwZxxOKNMWaFtbZLeY7Vlb9IBPAM9vG1HfyZlYvTZcmK2siKrDf4/sftXNzyRn7+qTdRxBU7XhOzwo+Cv0iYK9pjZ/8RJy6OkOF4j0PRc4i2jWiS/TQ56edzfDyamBUhdMNXJMwVre3PjEplZ81BHIqeQ728a2mRPYZY11nszMjUxKwIoit/kTBXkLLJ4wD7HZM4HLMIh+t4GuW8RC3XqUf3axEfq4lZEUTBXyTMNY+rxU9/zmOfYzwuDhPnvJW43FswHJus5Xl1r4lZkUHBXyQIKtJcrSy2H9hOToMX2JuzgBquthyXM4QatnWhfRJ0dR+RFPxFqlhVLHLisi7Gp47n0S8fJdeVS4OcftTL64mhcD7fQIVLOKvqRCaVS8FfpIqVp7laaTwDcIP6e8muN44N+5bS/cTujL92PHeN/yUgVTxarav6UrWPSBWr7EVOCgLwjoxDZMR8xMqc/mz6Yz2DO73M/Dvnc2KDEwNWxaPVuqovXfmLBJC3lEhlL3Iyet5mMnI380fN13FGbaN2Xjca5gzkh40tMNfnz9INVBWPVuuqvhT8RQLEV0rkxrMSmLEirdAVc3mvwo84j7DuzzEcrJlMNPE0zv4HtV3nAcUDcCCqeLRaV/WltI9IgPhKiSzalM4LvTuSEB+LIb/a5oXeHcscmBdtW8Tp407noGMmdfO60yJr7NHAD1UTgDUprPrSlb9IgJSUEqnIVXhGVgbD5w9n0spJnNjgRJ4+90OmfV2fTCr+SaKsNCms+lLwFwkQf1IiZS2TTN6UzH1z7mP34d08fO7DdGlwD28s+I1MZybRxpBnbZXX7WtSWPWktI9IgJSWEinLYuq///k7fT7uww0f3kDjOo1Z2m8p3RoP5amULUdPMHnWHn18BWMpTanB3xjztjFmjzFmnce2hsaYL4wxP7m/NvD43UhjzBZjzGZjTI9ADVwkVCWvTKNb0kIe/HAVtRxRxMc6vOb2/SmTtNYyZdUU2o9pT8rmFEZdOorU/ql0adFFZZZSIf6kfaYAbwLveWwbASyw1iYZY0a4f37UGNMeuBXoALQAvjTGnGytzUMkAnhrnxzriObVWzoVuxovrUxy2/5t3DP7Hr7Y+gXntzqfiT0nckqjU/w+XqQkpV75W2sXA/uKbL4eeNf9/btAL4/t06212dbabcAWoGsljVUk5JV2NV7wqaDNiDlEGePtIWgeV4NXv3uV08adxnc7vmPM1WP431//Vyjwg+9qHpVZij/Km/Nvaq3dBeD+2sS9PQHY7rHfDve2YowxA4wxqcaY1PT09HIOQyS0lHQ1XjTHn+dlCdUox3b21n6Eh+Y/xCWtL2HDfRu47//uI8oU/6+qMkupiMqu9vF2KeN1kWBr7QRgAuSv4VvJ4xAJCl8VPvG1HQz7aLXXgJ9fpZNDXt2ZpLmmE58Tx9TeU7nttNswPj4dgMospWLKG/x3G2OaW2t3GWOaA3vc23cAx3vs1xLYWZEBilQXySvTOJKTW2y7I9rwZ1au18A
PcMRsoEHLt9m4dyN3dLyD1658jUa1G/n1nCqzlPIqb9onBejr/r4v8KnH9luNMTWNMW2AtsCyig1RJPQVpHT2H3EW2h4f66BOjRicruKB30Um+xzj+b3mI/yZ8ydzb5/L+73f9zvwi1REqVf+xphpwMVAI2PMDuBJIAn4yBiTCPwG9AGw1q43xnwEbABygUGq9JFI4O1GL0CdmjFe00CZUSv4w/EmeWYvV7W+mw9v/Tf1atariqGKAGCsj4+iValLly42NTU12MMQKbOCGbreAnwBw7EbX3kcZL9jIodjFhHjasntJ4/i3TvurpKxSvgxxqyw1nYpz7Fq7yBSTkVr+r0paLlgsRyJXsw+xwRc/Emc8xbicm9h2aZYklem+czba5UsCRQFf5Fy8pXqKRDriCbTmUcue9lXYwyZ0cvd6+g+Rw3bBih5BS+tkiWBpN4+IuVU0kzahPhYRt3Qgeh6X7Kz1r1kRa2hgfPvNMv+19HAX9rjqH2DBJKu/EXKyVdNf0J8LO/0b8kN0/qwNXcptVyn09B5Pw7b3OfjeKP2DRJIuvIXKSdvM2xrOSwnnvgFp409nU1713FczgM0yRnlM/ADXHJKY6/b1b5BAknBX6ScenVOKLQiV1z97eQ0HMH7G5OoR1eaZ42jbt4VGK8T349ZtMl7exO1b5BAUtpHxK08lTW9OifQ47SGPPXVU7z83cs0djVmxs0zGPZuTe99TbzwlcZR+wYJJAV/EfyvrHkseS0fLP2Nggm7rhrryK3/FrsO/0Ji50RGXz6aBrENeDl+YYm1/55KSuOofYMEitI+IvhXWfNY8lre/z4/8Ls4zB+ON9kePYL0Q1k8fe50Jl03iQax+esaeUvZOKINUUUyQI5oozSOBIWu/CXieEvv+FNZM21pfrfyI1Hfs6/GWPLIoL7zBuJy72DeDw144opjx3lL2VxySmM+XLYdl+es+uBPsJcIpeAvEcVXeicu1kFGprPY/p4pmRy7j32O8RyJ+QaHqzWNcx6jpj0Z8J63L5qy6Za0sFiDN6fL+pzkJRJICv4SUXyld2o5oo7OyC1QUFljreW91e+xs9ZgXGQR77yL+rk3Yjz++/hTfqm6fQklCv4SUXwF2owjTl69pVOxdFCnNk6unHol83+eT8s6nXHtG4DDHl/o2Cjyc/ylVQv5mhSmun0JBt3wlYhS0sSpXp0TWDLiUrYlXcPiRy7il+yP6TC2A99u/5Y3r3qT1y/7lPoxJxQ6LtYRxSu3dAIotERjWkYmQz9cRaen55O8Mg1Q3b6EFl35S9jwp05/eI92xTpxFg3A6/esJzElkaVpS7nqL1fx1rVv8cPWaPdxrkLHvdC7I706J9AtaaHXJm8Zmc5iJaOq25dQoOAvYcHfOv2SAnB2bjYvfPMCz3/9PPVr1uf9G97n9o638+mqnV7X3/XsyFlS3t5zP9XtS6hQ8Jew4OtG7tOz1nsN9EUD8Pc7vicxJZEN6Ru4vePtvNbjNRrXaXz0pOJr/d2CoO8rn190P5FQoZy/VGvJK9PoluR7Nu3+I85CefiRM9cezcED/JnzJ0M+G8J5k8/jUPYhZt82m6m9p9K4Tn6ztdJ69hfcQ/CWz/e2n0io0JW/BE1FV6nyZyWtojxTMPN/ns+AWQP49cCvDPq/Qbxw2QvF1tEt6Yrd815BwbifnrW+2CLuuqkroUjBX4KiMlapKu2q3JftGbvpm9yX91a/R7vj2vH1377m/Fbne93XVzon2pijN3sLFKSTSjupaWlGCQVawF2CwleqJiE+liUjLvXrMdqMmOOzO0KD2g6spdCs3fx1dL/mQI0J2Kg/ebTbozx24WPUiqnl8zm8fbrwrPIpq8p+PIlsFVnAXTl/CQpf6ZS0jMxCOXlfklemEWV898nPcrq49ozmR/PwuewlvcZz7K3xEq3iWpHaP5XnLn2uUOBPXplG52fm03rEHFqPmEOnp+cDFOrZnxAfW6FAraUZJVQo7SNBUVJ1TNH
0T9E0ySWnNGbGijSfFTiQH1Bnr97FqBs6MPyzV/jNOQFj8ujb/jEm3fgkMVGF3/oFHTs9ZWQ6Gf7xakb3OcPvTyOlUYsHCRW68pegKKk6xvNKuCBN4lmxM/X73/zK9adn/cIT397M1tzXuLjN2fx4/3qm9Hm2WOBPXpnG1CKBv4DTZRn20WrajJhDt6SFfn0qKYmWZpRQoSt/CbiCK/e0jEyijSHPWhLiY7nxrIRiV9sF0jIyfd4XKO0ulSWPgzGfkBEzld1/1GBSz0n8vfPfMT7SRKPnbS7xMQs+YZTnpnRR/swwFqkKuuErAVVSOWasI5qaMVFeWykbytfqPsf8zB813iAn6mdi886lYc5ATohvWWJlTUk3jr0py01pb1TtI5WlIjd8FfwloEqagAX5VTlZTlehk0N5Ar+LbA7ETONgzEyiqE/DnHup4+pWbL+Cx07wCLqljdHbY2xLuqaMIxSpfKr2kZBV2o3MjCPOYtU0ZQ382VFryag7lIOO/1In7zJaZL3lNfDDsZOK52zf0mbnFqX8vIQDBX8JqNICpWcr5VfdrZFLEu2Rty9YR/f3miPJycvlqXOn0TR3KNHU9WtsnrN9i56AXrulE6/d0kktmCVs6YavBJS3G5wFPANpaa0aCiZCPfjhKgCORC11r6O7n3rOXjTIvZMnr7iRKQvnlGl8BZ9MSuq2qfy8hCMFfwkozxbKRat9PANpSa0aPPf9Z8o3/JI7hiMxX7vX0f0HNW074mMdQOndNYsq7ZOJWjBLuFLwl4DzJ4CWdG8gLSOTlz7fxKLt/2Uzj5MXnUmc8w7icm/CkB/0C7JB3j5pFNzkLXojWSkciWQK/hISSrpizzW7WZk5hu9W/kDNvFNp4ryfGrZVoX0y3J00S1qsRSWWIseo1FNCQvLKNIa68/kFLHkcip5NhuM/gCHeeTfxrmtx2eKTtSpaey9SHQWt1NMY84sxZq0xZpUxJtW9raEx5gtjzE/urw0q8hwSGXp1TjiatwfIMb/xe81H2F9jIjVdHWiRPYb6eT1xWaMKHJFKUBmlnpdYazt5nH1GAAustW2BBe6fRUr11HUdqOVwkRHzAbtqPkCu2cVxOcNokvMUMbYJcKyrZmV12RSJVIHI+V8PXOz+/l3gK+DRADyPVHNFc/AnH7+LX2P+SRa/Ujv3Iho6+xNN/NH9HVHG5xq8IlI2FQ3+FphvjLHAeGvtBKCptXYXgLV2lzGmSUUHKdWHr5uqvtoyZzrzcJHF2sMT+fanFKLtcTR2PkFtV9dij123VoyCvkglqWjw72at3ekO8F8YYzb5e6AxZgAwAKBVq1al7C2hwJ/lCb0tzZj6676jgb5g+9Tvf8MCmVEr+cPxJnlRu6mbezUNnH8litpen7+gokdEKq5Cwd9au9P9dY8x5hOgK7DbGNPcfdXfHNjj49gJwATIr/apyDgk8PxZc9fXKlXTlm4vtvBKLofY75jE4ZgFxLgSaJqdRC3XaSWOQT11RCpPuW/4GmPqGGPqFXwPXAGsA1KAvu7d+gKfVnSQEnz+LD/oa6KWZ+C3WA5HfcPOWvdyOHoR9Z030yL736UGflX0iFSuilz5NwU+cS+QEQN8YK393BizHPjIGJMI/Ab0qfgwJdj8WX7Q10StgpYOufzBvhrjyIz+nhqukzgu5xlq2BN9Pqe39ssiUjnKHfyttVuBM7xs/wO4rCKDktDjK7B7pmJ8rVLV+8zmvL1yMrujJgO5xDv/RlNzIzed3YpFm9IL3QT2/FkBXyRw1N5B/OLP8oNFWyvE13aQZdN4ZeUjZEWvpS6nE5c1mBPiTlJgFwkyBX/xS0k9c4ru16tzAjNW/Mq9nz7D3qj3IcpBw5zBNI66mqSbT1fQFwkBCv7iN38nV636fRV959zM4eifiM07h4Y59xLDcWTluY4uniIiwaXgL5UmKzeLZ/73DC8teQnrqkejnBHUdnXDcKwRW2nLOopI1VDwl0qx+NfF9J/Vnx//+JG
/dvor6zdcz54sR7H9VKsvEhoU/MNMVfesP5h9kEe/eJS3VrxF6/jWzL9zPpefdDnJrYovy6hafZHQoeAfRkqahQul36wt64lj9o+zuXfOvaQdTOPBcx7k2UuepU6NOoD/N4hFJDi0mEsY6Za00GstfoPaDrKcrmJX4Z6tkL0toF50n4L9Rn3+PRuPvMHhmMW0qteOj25+l7Nbnu33OLWilkjlCNpiLhJafN1M3X/EWWprBn/aN3zyww7u++Q1fsj+G4ejvyXOeQex+0ezK72l32MsOMmkZWRiOfbpJHllmt+PISIVp+AfRsp6M9XzZFFa+4ZfM37l77N7syv6XzhcCTTPfoP43NvIckYVOkGUxp+TjIgEnoJ/GBneo53XJQ49l0f05Hmy8HXiiIuN4sRR99PmtVM44FpLg5wBNM15sdAC6mUp3/SnR5CIBJ6Cfxjp1TnB6xKHT13XodR1b72dOGz0dja7HmJb7pvUdLWnRfZY6uddh6HwfmX5xOFrX5WAilQtVfuEmZJm4Y6et5m0jEyijSmUavE8Jn+fg7jqfsqOvKkYW4vjch6iTt4lhSZrFShr+aY/PYJEJPAU/CNEQXAvaUGWXp0TaNEkjcSUIazbs47auRfQ0HlPoXV0CxgoV6WOSkBFQoOCfwQp6Wbr5R3ieXzR47y+9HWa121Oyq0pJH1Sx2vpaEJ8LEtGXFrucWgBdpHgU/APc5419b5mdPx88Ds6juvHtoxtDDxrIEndk4irFUdeD83SFQlXCv5hKnllGk+lrCcj0/ei53n86V5H90vaRrXlq75fcVHri47+XikakfCl4B+GvM3W9WSxHIlawr4ab+HiIL3/Moj3bx5NrKN4xY1SNCLhScE/hFRW2wNvuf0Cnuvo1jFteebC6Tx0cY+KDl1EqhkF/xBRUlM2XycAXycLbxOmLJY/o+ez3/E2xjh5sfuLPHTuQ8RE6S0gEon0Pz9ElFSJ4y34l3SyKLrYutPs5A/Hm2RHryHWduRfl47hvm4XBPCvEZFQpxm+IaKsbQ9KOlkUzNa15HEgZia7ag4mJ2oLLRnC+9fN4b4Ligf+5JVpdEtaSJsRc+iWtFCN1kTCnK78Q0TRq3XP7d6UdLIh+rItAAAMeElEQVTo1TmBbQc28Pji4Ry2P9Ig6jxGX/Y6ied57/xanpSTiFRvuvIPEd566xjgklMae93f10mhWVw0/1zwTx75+mrq1M7gw5s+5I/HvvEZ+EGdNkUikYJ/iOjVOYEbz0oo1D3HAjNWpHlNwXhtxObYyPYag3n+m+e5o+MdbLhvAzd3uBljivfk8aROmyKRR2mfELJoU3qxWbiZzjyGfbSaBz9cVaiix3MC1o6MP8ip8z6/u1JoHdOaeXfO44qTrvD7ecuachKR6k9X/iHE15V2nrVeV73q1TmBf/TOxNX0IXa7ZjH07KGsvXdtmQI/+F4HQG0cRMKXrvyDzLNWP8oY8kpZU7kgF9/t5BoM+XwI09ZNo0PjDnyc+DHntDzH6+OWNmFMbRxEIo8WcA+i0toweGOxHIn+Chs3hYPZB/nnBf9k5AUjqRFdo8TH9bYYu4hUbxVZwF1X/kHkqw1DtI9PALlmD384xpAVvYJzjjuHST0n0aFJB78et6QJYyISeRT8g8hXjt9lLQkeN2EtLg5FzyHD8R7g4u+nPcWEGx4jOira6/Gq3hGR0uiGbxCVtJ5twU1Yp9nO7hqPsr/GeGq6TuG+9slMvvFJn4G/tMcVEQEF/6Aqqcrm6tMbc+ZpC9lV6wGcUds5LudBTo5+ge5tO1bocUVEQGmfKuGr8sZXlU1Ck510mXAVa/espZ7rQuKy+xNNAw5k5vrVdkHVOyJSGlX7BFhZKm8O5xzmiUVP8NrS12hWtxn1M+8l81DnYo/puYZuZa0BICLVT0hW+xhjrgReB6KBSdbapEA9Vyjzt/JmwdYF9J/Vn20Z27jnrHt4sfuLdHrqG6+PmZaRSbekhaRlZGLg6KxgNWQ
TEX8FJOdvjIkGxgBXAe2B24wx7QPxXKGutMqb/Zn7Sfw0ke7/6U50VDRf9f2Kt659i7hacT5v0BrwqAQqTA3ZRMQfgbrh2xXYYq3daq3NAaYD1wfouUJaSZU3MzbMoP3Y9ry7+l0e7fYoawauKbSAuq9On6Ul6lTSKSKlCVTwTwC2e/y8w73tKGPMAGNMqjEmNT09PUDDCD5vAdzhOICr4b+46eObaF63Ocv6LyOpe1KxBdR7dU7ghd4dSYiPxZCf6/fnDo1KOkWkNIHK+XvrIVwobllrJwATIP+Gb4DGEXSelTdpGUeIqfcVu8wktu/JIumyJB469yEc0Y4Sj/fM3xfk+n1RSaeI+CNQwX8HcLzHzy2BnQF6rpDXq3MCHU/IYsDsASzctpALT7iQiT0ncvJxJxfbt7TqneE92hWrHipIBSWo2kdE/BSo4L8caGuMaQOkAbcCtwfouUJariuX179/nccXPY4j2sFb17xF/7P6E2WKZ9z8WU6xpBr+ghNH0d7/IiJFBST4W2tzjTGDgXnkl3q+ba1dH4jnCmVrdq8hMSWR1J2pXNfuOsZePZaE+r6Dsb9loUVTQaB1eEWkbALW3sFaO9dae7K19iRr7ahAPU8oysrN4rGFj3HWhLP4ce82To55nNWr+nPz2M1el2QsUJGGbFqHV0TKQu0dKtk3v31D/1n92bR3Exe3vIm0X/qQ7awDlH41XpHlFNXJU0TKQo3dKsmh7EMMnjuYC965gExnJp/f8Tk56feS4w78BUq6Gq9IQzZ18hSRslDwrwRzf5pLh7EdGLt8LA90fYB1962jx196lPlq3Ftdv7+rb6mTp4iUhdI+FZB+OJ2h84bywdoPaN+4PUv+voRzjz/36O/Lk8bxdjPXH+rkKSJloeBfDtZapq2bxpDPh3Ag6wBPXvQkI88fSc2YmoX281aTH8ir8fKeOEQk8ij4l9H2A9sZOGcgc3+aS9eErky+bjKnNTnN6766GheRUKXg7yeXdTFu+ThGLBiBy7p4tcer3N/1/hKXUwRdjYtIaKrWwb+qFjLZtHcT/VL6sWT7Ei4/8XLGXzueNg3aVPrziIhUlWob/KtiRqszz8lLS17imcXPUMdRhynXT+HuM+7GGG9960REqo9qG/z9bYVQXqk7U0lMSWTN7jX0ad+Hf1/1b5rWbVrhxxURCQXVNvgHakbrEecRnlj0BK9+/yrN6jYj+ZZkrj8lItehEZEwVm2Df0VaIfiycNtC+s/qz9b9Wxlw5gBevPxF4mvFV2SYIiIhqdrO8K3MGa0ZWRn0S+nHZe9dRpSJYlHfRYzvOd5r4E9emUa3pIW0GTGHbkkLS2zUJiISqqrtlX9l1dDP3DiTQXMHkX44nUfOe4SnLn6q2HKKBdQ2WUTCRbUN/lCxGvrf//ydwXMHM2PjDDo168Sc2+dwZvMzSzwm0DeZRUSqSrUO/uVhreWdVe8wbP4wMp2ZPH/p8zx83sMlrqNbQG2TRSRcRFTw37p/KwNmDWDBtgVc0OoCJvacSLtG/t8jCMRNZhGRYKi2N3zLIs+VxyvfvcJpY09jWdoyxl0zjq/++lWZAj+obbKIhI+wv/Jfu3stiSmJLN+5nGtPvpZx14yjZf2W5XosNWoTkXARtsE/Ozeb5xY/R9KSJBrUasC0G6dxS4dbfLZm8LdPkBq1iUg4CMvg/+32b+mX0o+Nezdy1+l38UqPV2hUu5HP/VXCKSKRJqxy/oeyD3H/3Ps5/+3zOew8zGd3fMZ7N7xXYuCHkks4RUTCUdhc+X/202cMnDOQ7Qe2c3Wbv7I3rTf3TsqjRfzCUvPyKuEUkUhT7a/89x7Zy12f3MXVH1xNHUcdnj//E37+6WZ+P2CwHEvhlNSGwVeppko4RSRcVevgv3THUk4dcyrT103n8QsfZ+U9K5m1vF6ZUzjeSjgNcMkpjQMxbBGRoKvWwf+URqfQ7fhu/DDgB5655BlqxtQsVwqnV+cEbjwrAc86IAvMWJG
mxm0iEpaqdfCPqxVH8q3JdGza8ei28qZwFm1KxxbZppu+IhKuqnXw96a8s3B101dEIknYBf9enRN4oXdHEuJjMUBCfCwv9O5Yar2+bvqKSCQJm1JPT+WZhTu8R7tCE71AfXtEJHyFZfAvD/XtEZFIouDvQX17RCRShF3OX0RESqfgLyISgRT8RUQikIK/iEgEUvAXEYlAxtqiTQ2CMAhj0oFfgz2OIGkE7A32IEKIXo9j9FoUptejsEZAHWttuTpQhkTwj2TGmFRrbZdgjyNU6PU4Rq9FYXo9Cqvo66G0j4hIBFLwFxGJQAr+wTch2AMIMXo9jtFrUZhej8Iq9Hoo5y8iEoF05S8iEoEU/IPEGHOlMWazMWaLMWZEsMcTDMaYX4wxa40xq4wxqe5tDY0xXxhjfnJ/bRDscQaKMeZtY8weY8w6j20+/35jzEj3+2WzMaZHcEYdGD5ei6eMMWnu98cqY8zVHr8L29cCwBhzvDFmkTFmozFmvTFmiHt7pb0/FPyDwBgTDYwBrgLaA7cZY9oHd1RBc4m1tpNHydoIYIG1ti2wwP1zuJoCXFlkm9e/3/3+uBXo4D5mrPt9FC6mUPy1AHjV/f7oZK2dCxHxWgDkAsOstacC5wCD3H93pb0/FPyDoyuwxVq71VqbA0wHrg/ymELF9cC77u/fBXoFcSwBZa1dDOwrstnX3389MN1am22t3QZsIf99FBZ8vBa+hPVrAWCt3WWt/cH9/SFgI5BAJb4/FPyDIwHY7vHzDve2SGOB+caYFcaYAe5tTa21uyD/PwDQJGijCw5ff3+kvmcGG2PWuNNCBSmOiHotjDGtgc7AUirx/aHgHxzGy7ZILLvqZq09k/z01yBjzIXBHlAIi8T3zDjgJKATsAt42b09Yl4LY0xdYAYw1Fp7sKRdvWwr8TVR8A+OHcDxHj+3BHYGaSxBY63d6f66B/iE/I+pu40xzQHcX/cEb4RB4evvj7j3jLV2t7U2z1rrAiZyLI0REa+FMcZBfuCfaq2d6d5cae8PBf/gWA60Nca0McbUIP9GTUqQx1SljDF1jDH1Cr4HrgDWkf869HXv1hf4NDgjDBpff38KcKsxpqYxpg3QFlgWhPFVmYIg53YD+e8PiIDXwhhjgMnARmvtKx6/qrT3h9bwDQJrba4xZjAwD4gG3rbWrg/ysKpaU+CT/Pc4McAH1trPjTHLgY+MMYnAb0CfII4xoIwx04CLgUbGmB3Ak0ASXv5+a+16Y8xHwAbyK0EGWWvzgjLwAPDxWlxsjOlEfvriF+AeCP/Xwq0bcBew1hizyr3tH1Ti+0MzfEVEIpDSPiIiEUjBX0QkAin4i4hEIAV/EZEIpOAvIhKBFPxFRCKQgr+ISARS8BcRiUD/D1u5KDaEAEvzAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(myModel.test[myModel.target], testPred)\n", "plt.xlabel = 'True test values'\n", "plt.ylabel = 'Predicted test values'\n", "\n", "#plot line y = x\n", "x = np.arange(np.floor(myModel.test[myModel.target].min()), np.ceil(myModel.test[myModel.target].max()))\n", "plt.plot(x,x,color = 'green')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the points roughly follow the line y = x then that's an indication the model is working well enough" ] } ], "metadata": { "kernelspec": { "display_name": "Python (cgvae)", "language": "python", "name": "cgvae" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }